import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
# Run on CUDA when it is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Cache directory shared by the dataset download and the model/tokenizer weights.
store_path = '/media/dataset2/huggingface'
# Validation split of the raw WikiText-2 corpus.
dataset = load_dataset(
    'wikitext',
    'wikitext-2-raw-v1',
    split='validation',
    cache_dir=store_path,
)

def concatenate_texts(examples):
    """Merge every ``text`` entry of a batch into one space-separated string.

    Intended for ``Dataset.map(..., batched=True, batch_size=-1)``, where the
    whole split arrives as a single batch and is collapsed into one row.

    Args:
        examples: a batch mapping whose ``'text'`` value is a list of strings.

    Returns:
        ``{'text': [joined]}`` — a one-row batch holding the joined string.
    """
    pieces = examples['text']
    merged = ' '.join(pieces)
    return {'text': [merged]}

# Collapse the whole split into a single row, then take that row's string so
# the corpus can be tokenized in one call.
concat_text = dataset.map(
    concatenate_texts,
    batched=True,
    batch_size=-1,
)['text'][0]

model_name = "gpt2-xl"  # alternatives: "gpt2-medium", "gpt2-large", ...
tokenizer = GPT2TokenizerFast.from_pretrained(model_name, cache_dir=store_path)
# return_tensors='pt' yields a batch of one: shape (1, seq_len).
input_ids = tokenizer.encode(concat_text, return_tensors='pt')

print("Total tokens:", input_ids.shape[1])

# Weights are loaded directly in bfloat16.
model = GPT2LMHeadModel.from_pretrained(
    model_name,
    cache_dir=store_path,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)
# Evaluation mode (dropout off); gradients are still computed later for scoring.
model.eval()

# Filled by the attention forward hook, keyed by block index:
#   block_activations[i] -> attention output of block i for the current window
#   block_gradients[i]   -> gradient of the loss w.r.t. that output
block_activations = {}
block_gradients = {}

def forward_hook(module, input, output):
    """Record a block's attention output and arrange to capture its gradient.

    Registered as a forward hook on each GPT-2 block's self-attention module
    (GPT2Attention). The module must carry an integer ``block_idx`` attribute,
    which selects the slot in the global ``block_activations`` /
    ``block_gradients`` dicts.

    Args:
        module: the attention module that just ran.
        input: the module's positional inputs, (hidden_states, ...). Unused.
        output: the module's return value. Only element 0, the attention
            output of shape (batch, seq_len, hidden_dim), is used.
            NOTE(review): the rest of the tuple (present / attention weights)
            differs between transformers versions, so it is not relied on.

    Side effects:
        Stores the attention output in ``block_activations[block_idx]``. When
        that tensor is part of an autograd graph, also registers a tensor hook
        that stores its gradient in ``block_gradients[block_idx]`` during
        backward.
    """
    block_idx = module.block_idx
    attn_output = output[0]  # shape: (batch, seq_len, hidden_dim)

    block_activations[block_idx] = attn_output

    # register_hook raises on tensors that do not require grad, e.g. a forward
    # pass run under torch.no_grad(). This hook stays attached to the model for
    # its whole lifetime, so skip gradient capture rather than crash such calls.
    if not attn_output.requires_grad:
        return

    def grad_hook(grad):
        block_gradients[block_idx] = grad

    attn_output.register_hook(grad_hook)

# Tag every block's attention module with its depth so the shared hook knows
# which dict slot to fill, then attach the hook to that module.
for layer_idx, gpt2_block in enumerate(model.transformer.h):
    attn_module = gpt2_block.attn
    attn_module.block_idx = layer_idx
    attn_module.register_forward_hook(forward_hook)


max_length = model.config.n_positions
stride = 1024
seq_len = input_ids.size(1)

num_blocks = len(model.transformer.h)
block_scores = np.zeros(num_blocks)
block_counts = np.zeros(num_blocks)

# Slide a window over the token stream. Each window covers
# input_ids[begin_loc:end_loc]; only its last trg_len tokens are targets, and
# any earlier tokens serve as context only.
for i in tqdm(range(0, seq_len, stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, seq_len)
    trg_len = end_loc - i
    if trg_len <= 0:
        break

    # GPT2LMHeadModel shifts labels by one internally, so a window shorter
    # than 2 tokens (e.g. a 1-token tail chunk) has no predicted position.
    # Its loss is NaN, and the resulting NaN gradients would poison every
    # block score, so skip it before running the model.
    if end_loc - begin_loc < 2:
        continue

    input_ids_chunk = input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids_chunk.clone()
    # -100 labels are ignored by the loss; mask the context-only prefix.
    target_ids[:, :-trg_len] = -100

    outputs = model(input_ids_chunk, labels=target_ids)
    loss = outputs.loss

    model.zero_grad()
    loss.backward()  # fires the per-block grad hooks

    for idx in range(num_blocks):
        attn_grad = block_gradients.get(idx)
        if attn_grad is None:
            continue

        # NOTE(review): despite the "Grad-CAM" label, this score is the
        # summed gradient magnitude only; activations are not weighted in.
        # Upcast before reducing: a bf16 sum over ~1M elements comes back
        # as a bf16 scalar with only ~3 significant digits.
        gradcam_value = attn_grad.float().abs().sum().item()

        block_scores[idx] += gradcam_value
        block_counts[idx] += 1

    # Drop this window's tensors so they can be freed before the next one.
    for idx in range(num_blocks):
        block_activations[idx] = None
        block_gradients[idx] = None
    del outputs, loss

# Average per scored window; blocks never scored end up at 0.
mean_scores = block_scores / (block_counts + 1e-9)

for idx, score in enumerate(mean_scores):
    print(f"Block {idx} Grad-CAM score: {score:.4f}")

most_important_block = int(mean_scores.argmax())
print(f"\n** Most important block idx: {most_important_block} (score={mean_scores[most_important_block]:.4f}) **")

